// Copyright 2022 The Google Research Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <cmath>
#include <cstddef>
#include <cstdlib>
#include <iostream>

#include "glog/logging.h"
#include "gtest/gtest.h"
#include "third_party/ac_datatypes/include/ac_std_float.h"

// Fills `array[0..size)` with pseudo-random values spread over [-100, 100].
//
// Values come from POSIX random(), so the sequence is reproducible after a
// call to srandom() with a fixed seed. Each value is computed in double
// precision and then stored through a float into the bfloat16 element, so the
// stored value is rounded to bfloat16 precision.
//
// @param array destination; must hold at least `size` elements.
// @param size  number of elements to fill; 0 is a no-op.
void load_array_with_random(ac::bfloat16* array, size_t size) {
  // POSIX specifies random() returns values in [0, 2^31 - 1], independent of
  // RAND_MAX (which describes rand() and may be as small as 32767).
  constexpr double kRandomMax = 2147483647.0;
  constexpr double kLow = -100.0;
  constexpr double kSpan = 200.0;

  for (size_t i = 0; i < size; ++i) {
    // Normalize in floating point: the integer expression random() / RAND_MAX
    // is almost always 0, which would make every element the same constant.
    const double unit = static_cast<double>(random()) / kRandomMax;
    array[i] = static_cast<float>(kLow + unit * kSpan);
  }
}

// Compares two bfloat16 arrays element by element, printing each pair and a
// histogram of the absolute differences.
//
// The first four buckets ("< 0.001", "< 0.01", "< 0.1", "< 1") are cumulative:
// a difference of 0.005 is counted in "< 0.01", "< 0.1" and "< 1". The last
// bucket (printed as "> 1") counts every difference that is not < 1, i.e.
// differences >= 1 as well as NaN differences.
//
// When compiled with GTEST defined, it also expects more than 97% of the
// differences to be < 0.1 and all of them to be < 1.
//
// @param arrayA first array; must hold at least `size` elements.
// @param arrayB second array; must hold at least `size` elements.
// @param size   number of elements to compare; 0 is a no-op (there is nothing
//               to compare, and percentages would otherwise be 0/0).
void compare_arrays(ac::bfloat16* arrayA, ac::bfloat16* arrayB, size_t size) {
  if (size == 0) {
    return;
  }

  constexpr size_t kNumBuckets = 5;
  // Exclusive upper bounds for the cumulative buckets 0..3.
  constexpr double kThresholds[kNumBuckets - 1] = {0.001, 0.01, 0.1, 1.0};
  constexpr const char* kLabels[kNumBuckets] = {"< 0.001: ", "< 0.01: ",
                                                "< 0.1: ", "< 1: ", "> 1: "};
  size_t diff_buckets[kNumBuckets] = {0, 0, 0, 0, 0};

  for (size_t index = 0; index < size; ++index) {
    const float lhs = static_cast<float>(arrayA[index]);
    const float rhs = static_cast<float>(arrayB[index]);
    // std::fabs rather than unqualified abs: depending on which headers are
    // visible, plain abs may bind to C's int abs(int) and truncate the diff.
    const float diff = std::fabs(lhs - rhs);

    std::cout << arrayA[index] << " vs " << arrayB[index] << '\n';

    for (size_t bucket = 0; bucket < kNumBuckets - 1; ++bucket) {
      if (diff < kThresholds[bucket]) {
        ++diff_buckets[bucket];
      }
    }
    // Everything not below 1 (including NaN) lands in the overflow bucket.
    if (!(diff < 1.0)) {
      ++diff_buckets[kNumBuckets - 1];
    }
  }

  std::cout << "Differences Count:\n";
  for (size_t bucket = 0; bucket < kNumBuckets; ++bucket) {
    std::cout << kLabels[bucket] << diff_buckets[bucket] << "("
              << static_cast<float>(diff_buckets[bucket]) / size * 100.0
              << "%)\n";
  }
  std::cout << std::endl;

#ifdef GTEST
  // >97% should be within 0.1
  EXPECT_GT(static_cast<float>(diff_buckets[2]) / size, 0.97);

  // 100% should be within 1
  EXPECT_EQ(diff_buckets[3], size);
#endif
}